# -*- coding: UTF-8 -*-  
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
# Switch TensorFlow from eager execution to TF1-style graph execution.
tf.compat.v1.disable_eager_execution()
# Ask TensorFlow to grow GPU memory on demand for every visible GPU instead of
# reserving it all up front (a no-op when no GPU is listed).
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
from tensorflow.keras import layers, Model, optimizers

from agents.dqn import DQN
from agents.dqn_gpi import DQN_GPI
from agents.sfdqn import SFDQN
from agents.buffer import ReplayBuffer
from features.deep import DeepSF
from tasks.reacher import Reacher
from utils.config import parse_config_file

import copy
import argparse

# Parse the ablation settings from the command line.
parser = argparse.ArgumentParser(description='Successor Feature Deep Q-learning')
parser.add_argument('--test_task_idx', default=11, type=int, help='Target task used for transfer')
parser.add_argument('--gamma', default=0.8, type=float, help='discount factor to be used')
parser.add_argument('--include_target_dqn', default="False", type=str, help='whether or not to use UVFA for DQN')
parser.add_argument('--train_task_idxs', default='0,4,8,9', type=str, help='list of (four) custom train task indices to train on before transferring to test task')

args = parser.parse_args()

# ablation parameters
test_task_idx = args.test_task_idx
gamma = args.gamma

# --include_target_dqn arrives as text; only the exact strings "True" and
# "False" are accepted and converted to a bool.
if args.include_target_dqn not in ("True", "False"):
    raise ValueError(f'Unsupported input {args.include_target_dqn} for include_target_dqn (only "True"/"False" allowed)')
include_target_dqn = args.include_target_dqn == "True"

# Comma-separated source-task indices; exactly four are supported.
train_task_idxs = [int(i) for i in args.train_task_idxs.strip().split(',')]
if len(train_task_idxs) != 4:
    raise ValueError(f'Unsupported number of training tasks {len(train_task_idxs)} (exactly 4 required)')
# String form of the source tasks, used when naming result and figure files.
train_tasks_str = '-'.join(str(i) for i in train_task_idxs)
# add target task to the training tasks (this is needed to enable GPI)
train_task_idxs = train_task_idxs + [test_task_idx]

# Echo the effective experiment settings so every run is self-describing.
print('\n===========Important Experiment Parameters===========')
print(f"train_task_idxs: {train_tasks_str}")
print(f"test_task_idx: {test_task_idx}")
print(f"gamma: {gamma}")
print(f"include_target_dqn: {include_target_dqn}")
print('=====================================================\n')

# Load the experiment settings from the config file.
config_params = parse_config_file('reacher.cfg')

# Gamma ablation: the command-line discount factor overrides the configured one.
config_params['AGENT']['gamma'] = gamma

# Per-section views of the configuration.
gen_params, task_params = config_params['GENERAL'], config_params['TASK']
agent_params, dqn_params, sfdqn_params = (
    config_params[section] for section in ('AGENT', 'DQN', 'SFDQN')
)

n_samples = gen_params['n_samples']

# Goal lists: training targets first, then test targets, concatenated so that
# a single index addresses either kind.
goals, test_goals = task_params['train_targets'], task_params['test_targets']
all_goals = goals + test_goals

# Independent copy of the SFDQN settings with GPI switched off
# (SFDQN agent without GPI); the original section is left untouched.
sfdqn_wo_gpi_params = copy.deepcopy(config_params['SFDQN'])
sfdqn_wo_gpi_params['use_gpi'] = False



# tasks
def generate_tasks(include_target):
    """Instantiate the Reacher tasks used for training and for evaluation.

    Both returned lists are built from the same indices in the module-level
    ``train_task_idxs``, but each list holds its own freshly constructed
    ``Reacher`` instances (the training list is built first).

    Args:
        include_target: forwarded unchanged as the third argument of every
            ``Reacher`` constructor call.

    Returns:
        A ``(train_tasks, test_tasks)`` tuple of two lists.
    """
    def make_tasks():
        # One task per configured index, all sharing the full goal list.
        return [Reacher(all_goals, task_idx, include_target) for task_idx in train_task_idxs]

    return make_tasks(), make_tasks()


# keras model
def dqn_model_lambda(include_target):
    """Build and compile the Keras Q-network used by the DQN agents.

    The hidden layers (widths and activations) and the learning rate come from
    ``dqn_params['keras_params']``.

    Args:
        include_target: when truthy the network takes 6 input features,
            otherwise 4 (the two extra inputs presumably carry the goal
            location -- confirm against the task's state encoding).

    Returns:
        A compiled ``Model`` with 9 linear outputs, using the Adam optimizer
        and mean-squared-error loss.
    """
    keras_params = dqn_params['keras_params']
    # ADDED flexibility for DQN to be without goal locations as input features
    n_inputs = 6 if include_target else 4
    # Keras documents `shape` as a tuple, so pass one explicitly instead of
    # relying on a bare int being coerced.
    x = y = layers.Input(shape=(n_inputs,))
    for n_neurons, activation in zip(keras_params['n_neurons'], keras_params['activations']):
        y = layers.Dense(n_neurons, activation=activation)(y)
    y = layers.Dense(9, activation='linear')(y)
    model = Model(inputs=x, outputs=y)
    adam = optimizers.Adam(learning_rate=keras_params['learning_rate'])
    model.compile(optimizer=adam, loss='mse')
    return model


# keras model for the SF
def sf_model_lambda(x):
    """Build and compile the Keras network that outputs successor features.

    The hidden layers (widths and activations) and the learning rate come from
    ``sfdqn_params['keras_params']``.

    Args:
        x: Keras input tensor the network is built on; it is also used as the
            model's input.

    Returns:
        A compiled ``Model`` whose output is reshaped to
        ``(9, len(all_goals))``, using the Adam optimizer and
        mean-squared-error loss.
    """
    keras_params = sfdqn_params['keras_params']
    feature_dim = len(all_goals)

    # Stack the configured hidden layers on top of the provided input.
    hidden = x
    layer_specs = zip(keras_params['n_neurons'], keras_params['activations'])
    for width, act in layer_specs:
        hidden = layers.Dense(width, activation=act)(hidden)

    # Flat linear head, then reshaped into a (9, feature_dim) block.
    flat_sf = layers.Dense(9 * feature_dim, activation='linear')(hidden)
    sf_out = layers.Reshape((9, feature_dim))(flat_sf)

    network = Model(inputs=x, outputs=sf_out)
    network.compile(optimizers.Adam(learning_rate=keras_params['learning_rate']), 'mse')
    return network


def train():
    """Load saved transfer-performance curves and render the comparison figure.

    Despite its name, this function currently trains nothing: the training
    pipeline is kept below as commented-out reference code. The function only
    reads the two ``.npy`` result files (named after the train/test tasks,
    gamma and ``include_target_dqn``), plots SFDQN (GPI) against DQN (GPI) and
    saves the figure as a PDF under ``figures/paper/``, creating that
    directory if needed.

    Raises:
        FileNotFoundError: if one of the expected ``.npy`` files is missing
            (raised by ``np.load``).
    """
    # Local import: only the figure-saving step below needs it.
    import os

    # ------------------------------------------------------------------
    # Disabled training pipeline, kept for reference. It writes the .npy
    # files (already smoothed) that are loaded further down.
    # ------------------------------------------------------------------
    # # build SFDQN
    # print('building SFDQN')
    # deep_sf = DeepSF(keras_model_handle=sf_model_lambda, **sfdqn_params)
    # sfdqn = SFDQN(deep_sf=deep_sf, buffer=ReplayBuffer(sfdqn_params['buffer_params']),
    #               **sfdqn_params, **agent_params)

    # # train SFDQN
    # print('training SFDQN')
    # train_tasks, test_tasks = generate_tasks(False)
    # sfdqn_perf = sfdqn.train(train_tasks, n_samples, test_tasks=test_tasks, n_test_ev=agent_params['n_test_ev'], test_task_idx=-1)

    # # build SFDQN (wo GPI enabled)
    # print('building SFDQN (wo GPI enabled)')
    # deep_sf_wo_gpi = DeepSF(keras_model_handle=sf_model_lambda, **sfdqn_wo_gpi_params)
    # sfdqn_wo_gpi = SFDQN(deep_sf=deep_sf_wo_gpi, buffer=ReplayBuffer(sfdqn_wo_gpi_params['buffer_params']),
    #               **sfdqn_wo_gpi_params, **agent_params)

    # # train SFDQN
    # print('training SFDQN (wo GPI enabled)')
    # train_tasks, test_tasks = generate_tasks(False)
    # sfdqn_wo_gpi_perf = sfdqn_wo_gpi.train(train_tasks, n_samples, test_tasks=test_tasks, n_test_ev=agent_params['n_test_ev'], test_task_idx=-1)

    # # build DQN with GPI
    # print('building DQN with GPI')
    # dqn_gpi = DQN_GPI(model_lambda=dqn_model_lambda, buffer=ReplayBuffer(dqn_params['buffer_params']), include_target=include_target_dqn,
    #           **dqn_params, **agent_params)

    # # training DQN
    # print('training DQN with GPI')
    # train_tasks, test_tasks = generate_tasks(include_target_dqn)
    # dqn_gpi_perf = dqn_gpi.train(train_tasks, n_samples, test_tasks=test_tasks, n_test_ev=agent_params['n_test_ev'], test_task_idx=-1)

    # # smooth data
    # def smooth(y, box_pts):
    #     return np.convolve(y, np.ones(box_pts) / box_pts, mode='same')

    # sfdqn_perf = smooth(sfdqn_perf, 10)[:-5]
    # # dqn_perf = smooth(dqn_perf, 10)[:-5]
    # dqn_gpi_perf = smooth(dqn_gpi_perf, 10)[:-5]
    # sfdqn_wo_gpi_perf = smooth(sfdqn_wo_gpi_perf, 10)[:-5]

    # # save
    # np.save(f'sfdqn_perf_tr-{train_tasks_str}_te-{test_task_idx}_gamma-{gamma}', sfdqn_perf)
    # # np.save(f'dqn_perf_{test_task_idx}_gamma-{gamma}_include_target-{include_target_dqn}', dqn_perf)
    # np.save(f'dqn_perf_tr-{train_tasks_str}_te-{test_task_idx}_gamma-{gamma}_include_target_dqn-{include_target_dqn}', dqn_gpi_perf)
    # np.save(f'sfdqn_wo_gpi_perf_tr-{train_tasks_str}_te-{test_task_idx}_gamma-{gamma}', sfdqn_wo_gpi_perf)

    # Load the previously saved performance curves.
    sfdqn_perf = np.load(f'sfdqn_perf_tr-{train_tasks_str}_te-{test_task_idx}_gamma-{gamma}.npy')
    dqn_gpi_perf = np.load(f'dqn_perf_tr-{train_tasks_str}_te-{test_task_idx}_gamma-{gamma}_include_target_dqn-{include_target_dqn}.npy')

    # One x-axis unit per task phase: the source tasks followed by the target
    # task, which is the last entry of train_task_idxs.
    n_phases = len(train_task_idxs)
    tick_positions = list(range(1, n_phases + 1))
    tick_labels = [f'Src{idx+1}' for idx in train_task_idxs[:-1]] + [f'Trg{test_task_idx+1}']

    # Global font sizes for the figure.
    ticksize = 20
    textsize = 25
    plt.rc('font', size=textsize)  # controls default text sizes
    plt.rc('axes', titlesize=textsize)  # fontsize of the axes title
    plt.rc('axes', labelsize=textsize)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=ticksize)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=ticksize)  # fontsize of the tick labels
    plt.rc('legend', fontsize=ticksize)  # legend fontsize

    fig = plt.figure(figsize=(8, 6))
    ax = fig.gca()
    curves = (
        (sfdqn_perf, '#386cb0', 'SFDQN (GPI)'),
        (dqn_gpi_perf, '#33a02c', 'DQN (GPI)'),
    )
    for perf, color, label in curves:
        # Each curve gets an x-vector matching its own length, so differently
        # sized result arrays still plot over the same [0, n_phases] range.
        ax.plot(np.linspace(0, n_phases, perf.size), perf, color=color, linewidth=2, label=label)
    ax.set_xticks(tick_positions, labels=tick_labels)

    # Axis labels and the legend are drawn only for specific target indices
    # (presumably the panels on the edge of a multi-panel layout -- confirm).
    if test_task_idx in (1, 6):
        ax.set_ylabel('test task reward')
    if test_task_idx in (6, 7, 10, 11):
        ax.set_xlabel('training task')
    fig.tight_layout()
    if test_task_idx == 5:
        ax.legend(frameon=False)

    # savefig does not create missing directories, so make sure it exists.
    out_dir = 'figures/paper'
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/sfdqn_dqn_gpi_return_new_task_tr-{train_tasks_str}_te-{test_task_idx}_gamma-{gamma}_include_target_dqn-{include_target_dqn}_paper.pdf', format="pdf", bbox_inches="tight")
    # Release the figure once it has been written to disk.
    plt.close(fig)


# Script entry point: runs unconditionally whenever this module is executed or imported.
train()
